{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
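        "# Scan and Recurrent Kernel\n",
        "\n",
        "**Original author**: [Tianqi Chen](https://tqchen.github.io)\n",
        "\n",
        "This tutorial introduces how to express recurrent computation in TVM.\n",
        "\n",
        "Recurrent computation is a typical pattern in neural networks."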
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\n",
        "from tvm import te\n",
        "import numpy as np"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
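        "TVM supports a `scan` operator to describe symbolic loops.\n",
        "The `scan` op below computes the cumulative sum over the columns of X.\n",
        "\n",
        "The scan is carried out over the highest dimension of the tensor. `s_state` is a placeholder that describes the transition state of the scan. `s_init` describes how to initialize the first k timesteps. Here, since the first dimension of `s_init` is 1, it describes how to initialize the state at the first timestep.\n",
        "\n",
        "`s_update` describes how to update the value at timestep t. The update can refer back to the value at the previous timestep through the state placeholder. Note that it is invalid to refer to `s_state` at the current or a later timestep.\n",
        "\n",
        "The scan takes the state placeholder, the initial value, and the update description. It is also recommended (although not required) to list the inputs to the scan cell.\n",
        "The result of the scan is a tensor that gives the value of `s_state` after the updates over the time domain."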
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "m = te.var(\"m\")\n",
        "n = te.var(\"n\")\n",
        "X = te.placeholder((m, n), name=\"X\")\n",
        "s_state = te.placeholder((m, n))\n",
        "s_init = te.compute((1, n), lambda _, i: X[0, i])\n",
        "s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])\n",
        "s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X])"
      ]
    },
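    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a reading aid, the semantics of the scan defined above can be sketched in plain Python. The helper name `reference_scan` is hypothetical and not part of TVM, and lists of lists stand in for tensors so that the sketch has no dependencies.\n",
        "\n",
        "```python\n",
        "def reference_scan(X):\n",
        "    # Plain-Python model of the scan above (hypothetical helper, not a TVM API).\n",
        "    m, n = len(X), len(X[0])\n",
        "    state = [[0.0] * n for _ in range(m)]\n",
        "    # s_init: the first timestep copies row 0 of X.\n",
        "    for i in range(n):\n",
        "        state[0][i] = X[0][i]\n",
        "    # s_update: timestep t reads only timestep t - 1 of the state.\n",
        "    for t in range(1, m):\n",
        "        for i in range(n):\n",
        "            state[t][i] = state[t - 1][i] + X[t][i]\n",
        "    return state\n",
        "\n",
        "\n",
        "X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]\n",
        "result = reference_scan(X)\n",
        "```\n",
        "\n",
        "Each row of `result` holds the running column sums up to that timestep, which is what `np.cumsum(X, axis=0)` computes."
      ]
    },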
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
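        "## Schedule the Scan Cell\n",
        "\n",
        "We can schedule the body of the scan by scheduling the update and init parts separately. Note that it is invalid to schedule the first iteration dimension of the update part. To split on the time iteration, schedule on `scan_op.scan_axis` instead."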
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(X: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle, scan: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        m <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        n <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        stride <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        stride_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        X_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(X, (m, n), strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>(stride, stride_1), type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "        stride_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        stride_3 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        scan_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(scan, (m, n), strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>(stride_2, stride_3), type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "        blockIdx_x <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>env_thread(<span style=\"color: #BA2121\">&quot;blockIdx.x&quot;</span>)\n",
              "        threadIdx_x <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>env_thread(<span style=\"color: #BA2121\">&quot;threadIdx.x&quot;</span>)\n",
              "        scan_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((stride_2 <span style=\"color: #AA22FF; font-weight: bold\">*</span> m,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>scan_1<span style=\"color: #AA22FF; font-weight: bold\">.</span>data, type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "        X_2 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((stride <span style=\"color: #AA22FF; font-weight: bold\">*</span> m,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>X_1<span style=\"color: #AA22FF; font-weight: bold\">.</span>data, type<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;auto&quot;</span>)\n",
              "        <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>launch_thread(blockIdx_x, (n <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">255</span>) <span style=\"color: #AA22FF; font-weight: bold\">//</span> <span style=\"color: #008000\">256</span>):\n",
              "            T<span style=\"color: #AA22FF; font-weight: bold\">.</span>launch_thread(threadIdx_x, <span style=\"color: #008000\">256</span>)\n",
              "            <span style=\"color: #008000; font-weight: bold\">if</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>likely(blockIdx_x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">256</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> threadIdx_x <span style=\"color: #AA22FF; font-weight: bold\">&lt;</span> n):\n",
              "                scan_2[(blockIdx_x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">256</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> threadIdx_x) <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_3] <span style=\"color: #AA22FF; font-weight: bold\">=</span> X_2[\n",
              "                    (blockIdx_x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">256</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> threadIdx_x) <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_1\n",
              "                ]\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> scan_idx <span style=\"color: #008000; font-weight: bold\">in</span> range(m <span style=\"color: #AA22FF; font-weight: bold\">-</span> <span style=\"color: #008000\">1</span>):\n",
              "            T<span style=\"color: #AA22FF; font-weight: bold\">.</span>launch_thread(blockIdx_x, (n <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">255</span>) <span style=\"color: #AA22FF; font-weight: bold\">//</span> <span style=\"color: #008000\">256</span>)\n",
              "            T<span style=\"color: #AA22FF; font-weight: bold\">.</span>launch_thread(threadIdx_x, <span style=\"color: #008000\">256</span>)\n",
              "            <span style=\"color: #008000; font-weight: bold\">if</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>likely(blockIdx_x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">256</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> threadIdx_x <span style=\"color: #AA22FF; font-weight: bold\">&lt;</span> n):\n",
              "                cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> scan_idx <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">1</span>\n",
              "                scan_2[\n",
              "                    cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> (blockIdx_x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">256</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> threadIdx_x) <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_3\n",
              "                ] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                    scan_2[\n",
              "                        scan_idx <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_2\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> (blockIdx_x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">256</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> threadIdx_x) <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_3\n",
              "                    ]\n",
              "                    <span style=\"color: #AA22FF; font-weight: bold\">+</span> X_2[\n",
              "                        cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride <span style=\"color: #AA22FF; font-weight: bold\">+</span> (blockIdx_x <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">256</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> threadIdx_x) <span style=\"color: #AA22FF; font-weight: bold\">*</span> stride_1\n",
              "                    ]\n",
              "                )\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "s = te.create_schedule(s_scan.op)\n",
        "num_thread = 256\n",
        "block_x = te.thread_axis(\"blockIdx.x\")\n",
        "thread_x = te.thread_axis(\"threadIdx.x\")\n",
        "xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)\n",
        "s[s_init].bind(xo, block_x)\n",
        "s[s_init].bind(xi, thread_x)\n",
        "xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)\n",
        "s[s_update].bind(xo, block_x)\n",
        "s[s_update].bind(xi, thread_x)\n",
        "tvm.lower(s, [X, s_scan], simple_mode=True).show()"
      ]
    },
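    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The split-and-bind pattern used in the schedule above can be illustrated without a GPU. The sketch below assumes a hypothetical helper, `covered_indices`, that enumerates (block, thread) pairs the way the lowered code does and applies the same `< n` guard. It is an illustration only, not how TVM generates the kernel.\n",
        "\n",
        "```python\n",
        "def covered_indices(n, num_thread=256):\n",
        "    # Hypothetical helper: mimic split(factor=num_thread) followed by\n",
        "    # binding the outer axis to blockIdx.x and the inner axis to threadIdx.x.\n",
        "    num_blocks = (n + num_thread - 1) // num_thread  # ceiling division\n",
        "    indices = []\n",
        "    for block in range(num_blocks):\n",
        "        for thread in range(num_thread):\n",
        "            i = block * num_thread + thread\n",
        "            if i < n:  # guard for the partially filled last block\n",
        "                indices.append(i)\n",
        "    return num_blocks, indices\n",
        "\n",
        "\n",
        "num_blocks, indices = covered_indices(1000)\n",
        "```\n",
        "\n",
        "With `n = 1000`, four blocks of 256 threads are launched, and the guard masks the last 24 threads so that every column index is written exactly once."
      ]
    },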
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Build and Verify\n",
        "\n",
        "We can build the scan kernel like any other TVM kernel. Here we use numpy to verify the correctness of the result."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fscan = tvm.build(s, [X, s_scan], \"cuda\", name=\"myscan\")\n",
        "dev = tvm.cuda(0)\n",
        "n = 1024\n",
        "m = 10\n",
        "a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)\n",
        "a = tvm.nd.array(a_np, dev)\n",
        "b = tvm.nd.array(np.zeros((m, n), dtype=s_scan.dtype), dev)\n",
        "fscan(a, b)\n",
        "np.testing.assert_allclose(b.numpy(), np.cumsum(a_np, axis=0))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Multi-Stage Scan Cell\n",
        "\n",
        "In the example above, we described the scan cell using a single tensor\n",
        "computation stage, `s_update`. It is also possible to use multiple\n",
        "tensor stages in the scan cell.\n",
        "\n",
        "The following lines demonstrate a scan with two stage operations\n",
        "in the scan cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "m = te.var(\"m\")\n",
        "n = te.var(\"n\")\n",
        "X = te.placeholder((m, n), name=\"X\")\n",
        "s_state = te.placeholder((m, n))\n",
        "s_init = te.compute((1, n), lambda _, i: X[0, i])\n",
        "s_update_s1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] * 2, name=\"s1\")\n",
        "s_update_s2 = te.compute((m, n), lambda t, i: s_update_s1[t, i] + X[t, i], name=\"s2\")\n",
        "s_scan = tvm.te.scan(s_init, s_update_s2, s_state, inputs=[X])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "These intermediate tensors can also be scheduled normally.\n",
        "To ensure correctness, TVM creates a group constraint that forbids\n",
        "the body of the scan from being `compute_at` locations outside the scan loop."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "s = te.create_schedule(s_scan.op)\n",
        "xo, xi = s[s_update_s2].split(s_update_s2.op.axis[1], factor=32)\n",
        "s[s_update_s1].compute_at(s[s_update_s2], xo)\n",
        "print(tvm.lower(s, [X, s_scan], simple_mode=True))"
      ]
    },
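    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the recurrence that the two-stage cell above describes can be written out in plain Python. The helper `reference_two_stage_scan` is a hypothetical name introduced only for this sketch, and lists of lists stand in for tensors.\n",
        "\n",
        "```python\n",
        "def reference_two_stage_scan(X):\n",
        "    # Hypothetical plain-Python model of the two-stage scan cell.\n",
        "    m, n = len(X), len(X[0])\n",
        "    state = [list(X[0])]  # s_init: copy row 0 of X\n",
        "    for t in range(1, m):\n",
        "        s1 = [state[t - 1][i] * 2 for i in range(n)]  # stage s1\n",
        "        s2 = [s1[i] + X[t][i] for i in range(n)]  # stage s2\n",
        "        state.append(s2)\n",
        "    return state\n",
        "\n",
        "\n",
        "result = reference_two_stage_scan([[1.0, 1.0], [2.0, 0.0], [3.0, 1.0]])\n",
        "```\n",
        "\n",
        "Because `s1` at timestep t depends on the state at timestep t - 1, it has to be recomputed inside the time loop, which is consistent with the constraint that the scan body stays within the scan loop."
      ]
    },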
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Multiple States\n",
        "\n",
        "For complicated applications such as RNNs, we might need more than one\n",
        "recurrent state. Scan supports multiple recurrent states.\n",
        "The following example demonstrates how to build a recurrence with two states."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "m = te.var(\"m\")\n",
        "n = te.var(\"n\")\n",
        "l = te.var(\"l\")\n",
        "X = te.placeholder((m, n), name=\"X\")\n",
        "s_state1 = te.placeholder((m, n))\n",
        "s_state2 = te.placeholder((m, l))\n",
        "s_init1 = te.compute((1, n), lambda _, i: X[0, i])\n",
        "s_init2 = te.compute((1, l), lambda _, i: 0.0)\n",
        "s_update1 = te.compute((m, n), lambda t, i: s_state1[t - 1, i] + X[t, i])\n",
        "s_update2 = te.compute((m, l), lambda t, i: s_state2[t - 1, i] + s_state1[t - 1, 0])\n",
        "s_scan1, s_scan2 = tvm.te.scan(\n",
        "    [s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2], inputs=[X]\n",
        ")\n",
        "s = te.create_schedule(s_scan1.op)\n",
        "print(tvm.lower(s, [X, s_scan1, s_scan2], simple_mode=True))"
      ]
    },
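    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two-state recurrence above can likewise be modeled in plain Python to make the data flow explicit. The helper `reference_two_state_scan` is hypothetical, and its parameter `l` mirrors the symbolic width of the second state.\n",
        "\n",
        "```python\n",
        "def reference_two_state_scan(X, l):\n",
        "    # Hypothetical plain-Python model of the two-state recurrence.\n",
        "    m, n = len(X), len(X[0])\n",
        "    state1 = [list(X[0])]  # s_init1: copy row 0 of X\n",
        "    state2 = [[0.0] * l]  # s_init2: all zeros\n",
        "    for t in range(1, m):\n",
        "        # Both updates read only timestep t - 1 of either state.\n",
        "        new1 = [state1[t - 1][i] + X[t][i] for i in range(n)]\n",
        "        new2 = [state2[t - 1][i] + state1[t - 1][0] for i in range(l)]\n",
        "        state1.append(new1)\n",
        "        state2.append(new2)\n",
        "    return state1, state2\n",
        "\n",
        "\n",
        "state1, state2 = reference_two_state_scan([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], l=3)\n",
        "```\n",
        "\n",
        "Note that the second state reads the first state at the previous timestep, so the two states advance together inside a single time loop."
      ]
    },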
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Summary\n",
        "\n",
        "This tutorial provides a walkthrough of the scan primitive.\n",
        "\n",
        "- Describe a scan with init and update.\n",
        "- Schedule the scan cells like a normal schedule.\n",
        "- For complicated workloads, use multiple states and stages in the scan cell."
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
